
import asyncio
import collections
import logging
import typing

import pytz
from apscheduler.job import Job
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger

from wintersweet.asyncs.interface import TaskInterface
from wintersweet.utils.base import Utils

# Timezone used for scheduling; the ``_TIMEZONE`` environment variable
# (read through Utils.os.getenv) overrides the Asia/Shanghai default.
_TIMEZONE_NAME = Utils.os.getenv('_TIMEZONE', r'Asia/Shanghai')
TIMEZONE = pytz.timezone(_TIMEZONE_NAME)


# Only let apscheduler's logger emit WARNING and above.
_APSCHEDULER_LOGGER = logging.getLogger(r'apscheduler')
_APSCHEDULER_LOGGER.setLevel(logging.WARNING)


class BaseAsyncIOScheduler(AsyncIOScheduler):
    """AsyncIOScheduler that remembers the tag it was created for."""

    def __init__(self, tag, **scheduler_options):
        # Stored before the base initializer runs, exactly as before.
        self._tag = tag
        super(BaseAsyncIOScheduler, self).__init__(**scheduler_options)

    @property
    def tag(self):
        """Read-only tag this scheduler was created with."""
        return self._tag


class AsyncIOSchedulerManager:
    """Registry of started schedulers, one per tag, shared process-wide.

    Schedulers are created lazily on first use of a tag. Registration goes
    through ``dict.setdefault``: if two callers race to create the same tag,
    only the instance that ends up registered is started, and both callers
    receive that instance.
    """

    # tag -> started BaseAsyncIOScheduler; class-level, so shared by all callers.
    _schedulers = {}

    @classmethod
    def _create_scheduler(cls, tag='default'):
        """Return the started scheduler for ``tag``, creating it if needed."""
        existing = cls._schedulers.get(tag)
        if existing is not None:
            return existing

        scheduler = BaseAsyncIOScheduler(
            tag=tag,
            job_defaults={
                r'coalesce': False,
                r'max_instances': 1,
                r'misfire_grace_time': 10
            },
            timezone=TIMEZONE
        )
        result = cls._schedulers.setdefault(tag, scheduler)
        # Identity check: start only the instance we actually registered; a
        # losing duplicate is discarded without ever being started.
        if result is scheduler:
            result.start()

        return result

    @classmethod
    def add_job(cls, tag: str, **job_kwargs) -> Job:
        """Add a job to the scheduler for ``tag`` (created on demand).

        ``job_kwargs`` are forwarded unchanged to the scheduler's ``add_job``.
        """
        scheduler = cls._create_scheduler(tag)
        return scheduler.add_job(**job_kwargs)

    @classmethod
    def remove_job(cls, tag: str, job: Job):
        """Remove ``job`` from the scheduler for ``tag``; no-op if none exists."""
        scheduler = cls._schedulers.get(tag)
        if scheduler is not None:
            scheduler.remove_job(job.id)

    @classmethod
    def stop(cls):
        """Shut down every registered scheduler and reset the registry.

        Every scheduler gets a shutdown attempt even if an earlier one fails.
        The first failure is re-raised afterwards. Those schedulers are no
        longer reachable through the registry, so stopping early would leak
        them.
        """
        # Swap in a fresh registry first. Schedulers created while this loop
        # runs go into (and stay in) the new dict instead of being discarded.
        schedulers, cls._schedulers = cls._schedulers, {}

        errors = []
        for scheduler in schedulers.values():
            try:
                scheduler.shutdown()
            except Exception as err:
                errors.append(err)

        if errors:
            raise errors[0]


class BaseTask(TaskInterface):
    """Common base for tasks run through AsyncIOSchedulerManager.

    Subclasses implement ``_job_args`` to return the keyword arguments for
    ``AsyncIOSchedulerManager.add_job`` (the wrapped callable plus its
    trigger). Tasks with the same ``tag`` share one scheduler.
    """

    def __init__(self, func, tag='default', *args, **kwargs):
        super(BaseTask, self).__init__()
        # Scheduler tag this task is registered under.
        self._tag = tag
        # ``func`` packaged with ``args``/``kwargs`` by Utils.package_async_func
        # (presumably an async wrapper -- confirm in wintersweet.utils.base).
        self._func = Utils.package_async_func(func, *args, **kwargs)
        # The scheduled apscheduler Job while started, otherwise None.
        self._job = None

    def start(self):
        """Schedule this task unless already scheduled; return its Job."""
        if self._job is None:
            self._job = AsyncIOSchedulerManager.add_job(self._tag, **self._job_args())
            self._running = True
        return self._job

    def stop(self):
        """Remove this task from its scheduler and return True.

        Raises:
            RuntimeError: if the task is not currently started. The check is
                an explicit raise, not an ``assert``, so it also fires under
                ``python -O``.
        """
        if self._job is None:
            raise RuntimeError(f'{self.__class__.__name__} has not start')
        AsyncIOSchedulerManager.remove_job(self._tag, self._job)
        self._job = None
        self._running = False
        return True

    def _job_args(self):
        """Return the ``add_job`` keyword arguments; subclasses must override."""
        raise NotImplementedError(
            f'{self.__class__.__name__} must implement _job_args()'
        )


class IntervalTask(BaseTask):
    """Task that runs its callable repeatedly every ``interval`` seconds."""

    def __init__(self, interval: int, func, tag=None, *args, **kwargs):
        # Without an explicit tag, tasks are grouped by their class name.
        effective_tag = tag if tag else type(self).__name__
        super(IntervalTask, self).__init__(func, effective_tag, *args, **kwargs)

        # Passed to the interval trigger as ``seconds``.
        self._interval = interval

    def _job_args(self):
        """``add_job`` keyword arguments: the callable on an interval trigger."""
        return dict(
            func=self._func,
            trigger=r'interval',
            seconds=self._interval,
        )


class CrontabTask(BaseTask):
    """Task triggered according to a crontab expression."""

    def __init__(self, crontab, func, tag=None, *args, **kwargs):
        # Crontab expression; parsed by CronTrigger.from_crontab at start time.
        self._crontab = crontab
        # An optional ``TIMEZONE`` keyword overrides the module default. It is
        # popped here so it is not forwarded to BaseTask with the call args.
        self._tz = kwargs.pop('TIMEZONE', TIMEZONE)

        effective_tag = tag if tag else type(self).__name__
        super(CrontabTask, self).__init__(func, effective_tag, *args, **kwargs)

    def _job_args(self):
        """``add_job`` keyword arguments: the callable on a cron trigger."""
        cron_trigger = CronTrigger.from_crontab(self._crontab, self._tz)
        return {'func': self._func, 'trigger': cron_trigger}


class TaskManager:
    """Collects interval and crontab tasks and starts/stops them together."""

    def __init__(self):
        self._interval_tasks = []
        self._crontab_tasks = []

    def add_interval_task(self, interval: int, func, tag='IntervalTask', *args, **kwargs):
        """Create an IntervalTask, remember it and return it (not started)."""
        task = IntervalTask(interval, func, tag, *args, **kwargs)
        self._interval_tasks.append(task)
        return task

    def add_crontab_task(self, crontab, func, tag='CrontabTask', *args, **kwargs):
        """Create a CrontabTask, remember it and return it (not started)."""
        task = CrontabTask(crontab, func, tag, *args, **kwargs)
        self._crontab_tasks.append(task)
        return task

    async def register(self):
        """Start every remembered task, interval tasks first."""
        for task in self._interval_tasks:
            task.start()

        for task in self._crontab_tasks:
            task.start()

    async def shutdown(self):
        """Stop every started task, then forget all remembered tasks.

        Tasks that were never started, or were already stopped individually,
        are skipped. Their ``stop()`` would raise, abort the loop and leave
        the later tasks running.
        """
        for task in self._interval_tasks + self._crontab_tasks:
            # ``_job`` is the BaseTask attribute that is non-None while started.
            if task._job is not None:
                task.stop()

        self._interval_tasks.clear()
        self._crontab_tasks.clear()


class ParallelTask(asyncio.Future):
    """Run queued awaitables concurrently with an upper bound on parallelism.

    At most ``max_size`` queued awaitables run at once. Each time one
    finishes, the next queued one is started, until all have been scheduled.
    Once every one has finished, the future resolves to their results in
    queue order (also exposed via ``results``).

    Failure handling: the future settles only after every awaitable has
    finished. It then fails with the exception of the first (in queue order)
    awaitable that raised, or is cancelled if that first unsuccessful
    awaitable was cancelled.

    Suited to running a large number of jobs in parallel with a cap on
    concurrency, so that one slow job does not stall scheduling of the rest.

    Note: the queued awaitables are only started when the instance itself is
    awaited with ``await``. Helpers that only attach callbacks to a future,
    e.g. ``asyncio.gather(tasks)``, never start them.

    Usage::
      >>> tasks = ParallelTask(5)
      >>> for i in range(20): tasks.append(asyncio.sleep(i))
      >>> await tasks
      >>> print(tasks.results)
    """

    def __init__(self, max_size):
        # With no slot available nothing would ever start and ``await`` would
        # hang forever, so reject such limits up front.
        if max_size < 1:
            raise ValueError(f'max_size must be >= 1, got {max_size!r}')
        super(ParallelTask, self).__init__()
        # Maximum number of awaitables running at the same time.
        self._running_size = max_size
        # Scheduled asyncio tasks in queue order; replaced by their results
        # once all of them completed successfully.
        self._results = []
        # Number of concurrency slots currently occupied.
        self._running_tasks = 0
        # Awaitables not started yet; a deque so taking the head is O(1).
        self._wait_tasks = collections.deque()

    @property
    def results(self):
        """Scheduled tasks while running or after a failure; values on success."""
        return self._results

    def append(self, task: typing.Awaitable):
        """Queue one awaitable to be run by this ParallelTask."""
        self._wait_tasks.append(task)

    def extend(self, tasks: typing.Iterable[typing.Awaitable]):
        """Queue several awaitables, preserving their order."""
        self._wait_tasks.extend(tasks)

    def __await__(self):
        # Dispatch only on the first await. A future that is not done with
        # zero running tasks has not been dispatched yet. Later or concurrent
        # awaits just share the same outcome instead of resetting state
        # mid-run.
        if not self.done() and self._running_tasks == 0:
            self.__dispatch()
        return (yield from super().__await__())

    def __dispatch(self):
        """Fill the concurrency slots, or resolve to [] if nothing is queued."""
        self._results.clear()
        if not self._wait_tasks:
            self.set_result(self._results.copy())
            return
        while self._running_tasks < self._running_size and self._wait_tasks:
            self.__spawn(self._wait_tasks.popleft())
            self._running_tasks += 1

    def __spawn(self, awaitable):
        """Schedule ``awaitable`` and refill its slot when it completes."""
        # ensure_future accepts coroutines as well as futures and other awaitables.
        task = asyncio.ensure_future(awaitable)
        task.add_done_callback(self.__run_wait_task)
        self._results.append(task)

    def __run_wait_task(self, *args):
        """Done-callback: reuse the freed slot, or settle once all are done."""
        if self._wait_tasks:
            self.__spawn(self._wait_tasks.popleft())
            return
        self._running_tasks -= 1
        if self._running_tasks == 0:
            self.__finish()

    def __finish(self):
        """Resolve this future from the completed tasks."""
        if self.done():
            # Cancelled from outside (e.g. the awaiting task was cancelled);
            # setting a result now would raise InvalidStateError.
            return
        for task in self._results:
            if task.cancelled():
                # task.result() would raise CancelledError, a BaseException
                # that escapes this callback and leaves the future pending.
                self.cancel()
                return
            error = task.exception()
            if error is not None:
                self.set_exception(error)
                return
        self._results = [task.result() for task in self._results]
        self.set_result(self._results.copy())

    def __iter__(self):
        """Iterate over ``results`` (replaces Future.__iter__)."""
        yield from self._results

